import hashlib
import secrets
import time
from typing import Dict, Optional
from datetime import datetime, timedelta
import logging
from functools import wraps
from fastapi import HTTPException, status

# Module-level logger; used below for account-lockout warnings and security audit entries.
logger = logging.getLogger(__name__)

class SecurityConfig:
    """Class-level security settings.

    Some values are read by helpers in this module (password policy, login
    lockout, allowed file types); the rest are presumably consumed by callers
    elsewhere.
    """

    # --- Password policy ---------------------------------------------------
    MIN_PASSWORD_LENGTH = 8
    MAX_PASSWORD_LENGTH = 128
    REQUIRE_UPPERCASE = True
    REQUIRE_LOWERCASE = True
    REQUIRE_DIGITS = True
    REQUIRE_SPECIAL_CHARS = True

    # --- Sessions and login throttling (durations in seconds) --------------
    SESSION_TIMEOUT = 60 * 60       # one hour
    MAX_LOGIN_ATTEMPTS = 5
    LOCKOUT_DURATION = 15 * 60      # fifteen minutes

    # --- API limits --------------------------------------------------------
    MAX_REQUEST_SIZE = 10 << 20     # 10 MiB, in bytes
    MAX_REQUESTS_PER_MINUTE = 60
    MAX_ADMIN_REQUESTS_PER_MINUTE = 30

    # --- Uploaded content --------------------------------------------------
    ALLOWED_FILE_TYPES = ['.json', '.txt']
    MAX_UPLOAD_SIZE = 5 << 20       # 5 MiB, in bytes

    # --- Logging -----------------------------------------------------------
    LOG_SENSITIVE_DATA = False
    MAX_LOG_SIZE = 100 << 20        # 100 MiB, in bytes

class PasswordValidator:
    """Password-strength checks plus PBKDF2-HMAC-SHA256 hashing and verification."""

    @staticmethod
    def validate_password(password: str) -> tuple[bool, str]:
        """Check ``password`` against the policy in ``SecurityConfig``.

        Rules are evaluated in this order and the first failure wins:
        minimum length, maximum length, uppercase, lowercase, digit,
        special character. Character-class rules are skipped when the
        corresponding ``REQUIRE_*`` flag is false.

        Returns:
            ``(True, message)`` if every enabled rule passes, otherwise
            ``(False, message)`` describing the first rule that failed.
        """
        cfg = SecurityConfig
        length = len(password)
        if length < cfg.MIN_PASSWORD_LENGTH:
            return False, f"密码长度至少{cfg.MIN_PASSWORD_LENGTH}位"
        if length > cfg.MAX_PASSWORD_LENGTH:
            return False, f"密码长度不能超过{cfg.MAX_PASSWORD_LENGTH}位"

        # Only this fixed set counts as "special"; other punctuation such as
        # '~', '/', quotes or a space does not satisfy the rule.
        special = "!@#$%^&*()_+-=[]{}|;:,.<>?"
        character_rules = (
            (cfg.REQUIRE_UPPERCASE, lambda ch: ch.isupper(), "密码必须包含大写字母"),
            (cfg.REQUIRE_LOWERCASE, lambda ch: ch.islower(), "密码必须包含小写字母"),
            (cfg.REQUIRE_DIGITS, lambda ch: ch.isdigit(), "密码必须包含数字"),
            (cfg.REQUIRE_SPECIAL_CHARS, lambda ch: ch in special, "密码必须包含特殊字符"),
        )
        for enabled, matches, failure_message in character_rules:
            if enabled and not any(matches(ch) for ch in password):
                return False, failure_message

        return True, "密码强度符合要求"

    @staticmethod
    def hash_password(password: str, salt: Optional[str] = None) -> tuple[str, str]:
        """Derive a PBKDF2-HMAC-SHA256 hash of ``password`` (100,000 iterations).

        When ``salt`` is None a fresh salt is generated with
        ``secrets.token_hex(32)`` (64 hex characters). Password and salt are
        both UTF-8 encoded before hashing.

        Returns:
            ``(hex_digest, salt)``; store both to verify later.
        """
        effective_salt = secrets.token_hex(32) if salt is None else salt
        iterations = 100000
        digest = hashlib.pbkdf2_hmac(
            'sha256',
            password.encode('utf-8'),
            effective_salt.encode('utf-8'),
            iterations,
        )
        return digest.hex(), effective_salt

    @staticmethod
    def verify_password(password: str, password_hash: str, salt: str) -> bool:
        """Return True if ``password`` hashes to ``password_hash`` under ``salt``.

        Uses ``secrets.compare_digest`` so the comparison time does not leak
        how many leading characters matched.
        """
        candidate = PasswordValidator.hash_password(password, salt)[0]
        return secrets.compare_digest(candidate, password_hash)

class LoginAttemptTracker:
    """In-memory tracker of login attempts and temporary lockouts, keyed by identifier.

    All state lives in two dicts on the instance; nothing is persisted.
    """

    # Attempts older than this are discarded whenever a new attempt is recorded.
    ATTEMPT_WINDOW = timedelta(minutes=15)

    def __init__(self):
        # identifier -> list of {'time': datetime, 'success': bool}, oldest first
        self.attempts: Dict[str, list] = {}
        # identifier -> datetime at which the lockout ends
        self.lockouts: Dict[str, datetime] = {}

    def record_attempt(self, identifier: str, success: bool) -> None:
        """Record one login attempt and lock ``identifier`` after too many failures.

        Attempts older than ``ATTEMPT_WINDOW`` are pruned first. Only the run of
        consecutive failures since the most recent successful attempt counts
        toward ``SecurityConfig.MAX_LOGIN_ATTEMPTS``, so a successful login
        resets the counter. When the limit is reached the identifier is locked
        for ``SecurityConfig.LOCKOUT_DURATION`` seconds from now.
        """
        now = datetime.now()
        cutoff = now - self.ATTEMPT_WINDOW

        history = [a for a in self.attempts.get(identifier, []) if a['time'] > cutoff]
        history.append({'time': now, 'success': success})
        self.attempts[identifier] = history

        if success:
            return

        # Count trailing failures; an earlier success breaks the run.
        consecutive_failures = 0
        for attempt in reversed(history):
            if attempt['success']:
                break
            consecutive_failures += 1

        if consecutive_failures >= SecurityConfig.MAX_LOGIN_ATTEMPTS:
            self.lockouts[identifier] = now + timedelta(seconds=SecurityConfig.LOCKOUT_DURATION)
            logger.warning("账户被锁定: %s", identifier)

    def is_locked(self, identifier: str) -> bool:
        """Return True while ``identifier`` is locked; expired lockouts are removed."""
        until = self.lockouts.get(identifier)
        if until is None:
            return False
        if datetime.now() < until:
            return True
        # Lockout has expired: drop the stale entry.
        del self.lockouts[identifier]
        return False

    def get_remaining_lockout_time(self, identifier: str) -> Optional[int]:
        """Return the whole seconds left on ``identifier``'s lockout, or None if not locked.

        An expired lockout is treated as "not locked" and removed, matching
        ``is_locked``. Truncation means a lockout with under one second left
        reports 0.
        """
        until = self.lockouts.get(identifier)
        if until is None:
            return None
        remaining = (until - datetime.now()).total_seconds()
        if remaining <= 0:
            del self.lockouts[identifier]
            return None
        return int(remaining)

class InputSanitizer:
    """Stateless helpers for cleaning and validating untrusted input."""

    @staticmethod
    def sanitize_string(input_str: str, max_length: int = 1000) -> str:
        """Remove control characters, truncate, then strip surrounding whitespace.

        Removed: C0 controls (U+0000-U+001F), DEL (U+007F) and C1 controls
        (U+0080-U+009F), except newline, carriage return and tab, which are
        kept. Truncation to ``max_length`` happens before stripping, so the
        result may be shorter than ``max_length``.

        Raises:
            ValueError: if ``input_str`` is not a ``str``.
        """
        if not isinstance(input_str, str):
            raise ValueError("输入必须是字符串")

        def keep(ch: str) -> bool:
            code = ord(ch)
            is_control = code < 0x20 or 0x7F <= code <= 0x9F
            return not is_control or ch in '\n\r\t'

        sanitized = ''.join(ch for ch in input_str if keep(ch))
        return sanitized[:max_length].strip()

    @staticmethod
    def validate_filename(filename: str) -> bool:
        """Return True if ``filename`` is a bare name with an allowed extension.

        Rejects empty names, names containing control characters, path
        separators / traversal sequences or other characters reserved on
        Windows, and names whose extension (case-insensitive) is not listed in
        ``SecurityConfig.ALLOWED_FILE_TYPES``.
        """
        if not filename:
            return False

        # Control characters (NUL in particular, which lower-level APIs may
        # treat as a terminator) must never reach the filesystem layer.
        if any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in filename):
            return False

        dangerous = ('..', '/', '\\', ':', '*', '?', '"', '<', '>', '|')
        if any(token in filename for token in dangerous):
            return False

        return filename.lower().endswith(tuple(SecurityConfig.ALLOWED_FILE_TYPES))

    @staticmethod
    def validate_json_content(content: str) -> bool:
        """Return True if ``content`` parses as JSON nested at most 10 levels deep.

        Never raises for bad input; returns False instead. That covers
        malformed JSON, undecodable bytes, integer literals too long to
        convert (``ValueError`` on Python 3.11+), non-string input
        (``TypeError``), and input so deeply nested that the parser recurses
        too far.
        """
        import json

        try:
            data = json.loads(content)
        except (ValueError, TypeError, RecursionError):
            # JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            return False

        def within_depth(obj, depth=0):
            if depth > 10:  # maximum nesting depth
                return False
            if isinstance(obj, dict):
                return all(within_depth(v, depth + 1) for v in obj.values())
            if isinstance(obj, list):
                return all(within_depth(item, depth + 1) for item in obj)
            return True

        return within_depth(data)

class SecurityHeaders:
    """Supplies a fixed set of HTTP response security headers."""

    @staticmethod
    def get_security_headers() -> Dict[str, str]:
        """Return a new ``{header_name: value}`` dict.

        A fresh dict is built on every call, so callers may modify it freely.

        NOTE(review): ``X-XSS-Protection: 1; mode=block`` is deprecated
        (current guidance is ``0`` or omitting it), and ``'unsafe-inline'`` in
        ``script-src`` weakens the CSP. Values are kept as-is; confirm with the
        frontend before tightening.
        """
        header_pairs = [
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "DENY"),
            ("X-XSS-Protection", "1; mode=block"),
            ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
            ("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'"),
            ("Referrer-Policy", "strict-origin-when-cross-origin"),
            ("Permissions-Policy", "geolocation=(), microphone=(), camera=()"),
        ]
        headers: Dict[str, str] = {}
        for name, value in header_pairs:
            headers[name] = value
        return headers

def security_audit_log(action: str, user: str, ip: str, details: str = ""):
    """Emit one security-audit line at INFO level on the module logger.

    Format::

        [SECURITY] <timestamp> | Action: <action> | User: <user> | IP: <ip>[ | Details: <details>]

    The timestamp is ``datetime.now().isoformat()`` (local time, no tz).
    ``Details`` is appended only when ``details`` is truthy.

    Carriage returns and line feeds in every field are replaced with the
    two-character sequences ``\\r`` / ``\\n``. Caller-supplied values such as
    user names or free-form details therefore cannot start a forged log line
    (CWE-117 log injection).
    """
    def clean(value) -> str:
        return str(value).replace("\r", "\\r").replace("\n", "\\n")

    timestamp = datetime.now().isoformat()
    log_entry = (
        f"[SECURITY] {timestamp} | Action: {clean(action)} "
        f"| User: {clean(user)} | IP: {clean(ip)}"
    )
    if details:
        log_entry += f" | Details: {clean(details)}"

    # Pass the entry as an argument so '%' characters in it are not treated
    # as format directives by the logging module.
    logger.info("%s", log_entry)

def require_secure_connection(func):
    """Decorator placeholder for enforcing HTTPS on an endpoint.

    NOTE(review): no HTTPS check is implemented yet. The wrapper only
    forwards the call, so this decorator currently provides no protection.

    Coroutine functions get an ``async`` wrapper so the decorated object is
    still a coroutine function. Frameworks such as FastAPI inspect this to
    decide whether to await the endpoint. A plain sync wrapper around an
    ``async def`` would return an un-awaited coroutine instead.
    """
    import inspect

    if inspect.iscoroutinefunction(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            # HTTPS enforcement for production would go here.
            return await func(*args, **kwargs)
        return async_wrapper

    @wraps(func)
    def wrapper(*args, **kwargs):
        # HTTPS enforcement for production would go here.
        return func(*args, **kwargs)
    return wrapper

# Shared module-level instances.
# NOTE(review): login_tracker keeps its state in instance dicts, so lockouts
# are not shared across worker processes — confirm this is acceptable.
password_validator = PasswordValidator()
input_sanitizer = InputSanitizer()
login_tracker = LoginAttemptTracker()